##############################################
# FUNCTION: apply.penalty()
#
# Applies a user-defined penalty to a distance structure created by the
# rcbalance package. Pairs that meet a condition on one covariate become
# more expensive, which discourages them in the match.
#
#
# Arguments:
#
# dist.struct :         the distance structure output by rcbalance. It must
#                       be a list with one element per case (z == 1), in data
#                       order. Element i holds the distances from case i to
#                       the controls (z == 0) that share its exact group, in
#                       data order. The penalty for case i is added
#                       element-wise to dist.struct[[i]], so the lengths must
#                       agree; a mismatch raises an error.
#                       NOTE(review): if controls were dropped when the
#                       structure was built (e.g. by a caliper), the lengths
#                       will not agree.
#
# z :                   a vector of treatment (1) and control (0) indicators
#                       (same as build.dist.struct)
#
# X :                   a data frame (same as build.dist.struct); only the
#                       column selected by `var` is used
#
# exact = NULL :        vector of exact-matching groups (same as
#                       build.dist.struct). NULL means no exact matching:
#                       every control is a candidate for every case.
#
# var = NULL :          name (or index) of the column of X to apply the
#                       penalty to; required
#
# type = 'scaled' :     type of penalty to apply, one of:
#                         'scaled'      - pen.size * |case - control|
#                         'caliper'     - pen.size when |case - control|
#                                         exceeds the caliper
#                         'directional' - pen.size when the difference points
#                                         in the direction given by
#                                         `direction`
#                         'concordance' - pen.size when case != control
#
# pen.size = 0 :        size of the penalty when the condition is met (the
#                       multiplier for 'scaled')
#
# cal.size = 0.2 :      size of the caliper (in standard deviations if
#                       cal.sd = TRUE, otherwise in the units of `var`)
#
# cal.sd = TRUE :       if TRUE the caliper is in units of standard
#                       deviations; if FALSE it is in the original variable
#                       units
#
# direction = 'case' :  for type = 'directional', one of ('case', 'control').
#                       'case' penalizes pairs where the case's value is
#                       greater than the control's. 'control' penalizes pairs
#                       where the control's value is greater. For a binary
#                       covariate, the penalty applies when the member of the
#                       pair named by `direction` has value 1 and the other
#                       member has value 0, which discourages such pairs.
#
# own.sd = FALSE :      for a caliper with cal.sd = TRUE. If FALSE, the
#                       standard deviation of `var` is computed from X (cases
#                       and controls together). If TRUE, sd.var is used.
#
# sd.var = NULL :       user-supplied standard deviation; required when
#                       own.sd = TRUE
#
# Returns: a copy of dist.struct with the penalties added.
##############################################

apply.penalty <- function(dist.struct, z, X, exact = NULL, var = NULL,
                          type = "scaled", pen.size = 0, cal.size = 0.2,
                          cal.sd = TRUE, direction = "case", own.sd = FALSE,
                          sd.var = NULL) {
  # Validate arguments up front. An unknown type/direction previously fell
  # through every branch and silently corrupted the result.
  valid.types <- c("scaled", "caliper", "directional", "concordance")
  if (length(type) != 1 || !(type %in% valid.types)) {
    stop("`type` must be one of: ", paste(valid.types, collapse = ", "),
         call. = FALSE)
  }
  if (type == "directional" &&
      (length(direction) != 1 || !(direction %in% c("case", "control")))) {
    stop("`direction` must be 'case' or 'control'.", call. = FALSE)
  }
  if (length(var) != 1) {
    stop("`var` must name a single column of `X`.", call. = FALSE)
  }
  values <- X[[var]]
  if (is.null(values)) {
    stop("`var` must name a single column of `X`.", call. = FALSE)
  }

  # Without exact matching, all units form one group, so every control is a
  # candidate for every case.
  if (is.null(exact)) {
    exact <- rep(1L, length(z))
  }

  # Split the covariate and exact groups into cases and controls, keeping
  # data order.
  case.rows <- which(z == 1)
  control.rows <- which(z == 0)
  case.values <- values[case.rows]
  control.values <- values[control.rows]
  case.groups <- exact[case.rows]
  control.groups <- exact[control.rows]

  # Caliper width, expressed in the units of `var`.
  if (type == "caliper") {
    if (!cal.sd) {
      limit <- cal.size
    } else if (!own.sd) {
      limit <- cal.size * sd(values)
    } else if (is.null(sd.var)) {
      stop("`sd.var` must be supplied when own.sd = TRUE.", call. = FALSE)
    } else {
      limit <- cal.size * sd.var
    }
  }

  # Penalty indicator/magnitude (before scaling by pen.size) for one case
  # value `a` against a vector of candidate control values `b`.
  penalty <- switch(type,
    scaled      = function(a, b) abs(a - b),
    caliper     = function(a, b) abs(a - b) > limit,
    directional = if (direction == "case") {
      function(a, b) (a - b) > 0
    } else {
      function(a, b) (a - b) < 0
    },
    concordance = function(a, b) a != b
  )

  # One penalty vector per case, covering the controls in the case's exact
  # group in data order. This matches the intended layout of dist.struct.
  dist.struct.pen <- vector("list", length(case.rows))
  for (i in seq_along(case.rows)) {
    eligible <- which(control.groups == case.groups[i])
    dist.struct.pen[[i]] <-
      penalty(case.values[i], control.values[eligible]) * pen.size
  }

  if (length(dist.struct) != length(dist.struct.pen)) {
    stop(sprintf("`dist.struct` has %d elements but there are %d cases.",
                 length(dist.struct), length(dist.struct.pen)),
         call. = FALSE)
  }

  # Start from a copy so any class/attributes of dist.struct are preserved.
  dist.struct.new <- dist.struct
  for (k in seq_along(dist.struct)) {
    # A length mismatch would silently recycle and misalign the penalties.
    if (length(dist.struct[[k]]) != length(dist.struct.pen[[k]])) {
      stop(sprintf(paste0("`dist.struct[[%d]]` has %d distances but case %d ",
                          "has %d candidate controls in its exact group."),
                   k, length(dist.struct[[k]]), k,
                   length(dist.struct.pen[[k]])),
           call. = FALSE)
    }
    dist.struct.new[[k]] <- dist.struct[[k]] + dist.struct.pen[[k]]
  }

  dist.struct.new
}